envs package

envs.register_env(environment_name: str) Callable[[_T], _T]

the decorator to register an environment

to add a new class as a registered environment, ensure the class is a subclass of envs.BaseEnv and add the following code above the class definition:

@register_env("env_name")
class YourEnvClass(BaseEnv):
    ...
Parameters:

environment_name (str) – the name of the registered environment

Returns:

the decorator to register the environment

Return type:

Callable[[_T], _T]
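
A minimal sketch of how such a registration decorator is commonly implemented, assuming a module-level dictionary registry; the name _ENV_REGISTRY and the exact storage are illustrative, not the package's actual internals:

from typing import Callable, Dict, TypeVar

_T = TypeVar("_T")

# hypothetical registry mapping environment names to environment classes
_ENV_REGISTRY: Dict[str, type] = {}

def register_env(environment_name: str) -> Callable[[_T], _T]:
    """Return a decorator that records the decorated class under the given name."""
    def decorator(cls: _T) -> _T:
        _ENV_REGISTRY[environment_name] = cls  # registered classes can later be looked up by name
        return cls
    return decorator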

Submodules

envs.BaseEnv module

class envs.BaseEnv.BaseEnv(args: Namespace, device: Optional[str] = None)

Bases: object

the base environment class, should be overridden by specific environments

static add_args(parser: ArgumentParser) None

add arguments to the parser

to add arguments to the parser, modify the method as follows:

@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        ...
    )

then add arguments to the parser

Parameters:

parser (argparse.ArgumentParser) – the parser to add arguments to
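
For example, a subclass could register a hypothetical --window_size flag; the argument name and default below are purely illustrative:

import argparse

from envs.BaseEnv import BaseEnv

class MyEnv(BaseEnv):
    @staticmethod
    def add_args(parser: argparse.ArgumentParser) -> None:
        # illustrative argument only; real environments define their own flags
        parser.add_argument(
            "--window_size",
            type=int,
            default=10,
            help="number of past time steps included in each state",
        )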

__init__(args: Namespace, device: Optional[str] = None) None

initialize the environment

Parameters:
  • args (argparse.Namespace) – arguments

  • device (Optional[str], optional) – device to run the environment. Defaults to None, which means to use the GPU if available.

initialize_weight() None

initialize the portfolio weight, risk free asset weight, and value

get_asset_num() int

get the number of assets, should be overridden by specific environments

Raises:

NotImplementedError – asset num not implemented

Returns:

the number of assets

Return type:

int

to(device: str) None

move the environment to the given device

Parameters:

device (torch.device) – the device to move to

train_time_range() range

the range of time indices, should be overridden by specific environments

Raises:

NotImplementedError – time_range not implemented

Returns:

the range of time indices

Return type:

range

test_time_range() range

the range of time indices, should be overridden by specific environments

Raises:

NotImplementedError – time_range not implemented

Returns:

the range of time indices

Return type:

range

state_dimension() Dict[str, Size]

the dimension of the state tensors, should be overridden by specific environments

Raises:

NotImplementedError – state_dimension not implemented

Returns:

the dimension of the state tensors

Return type:

Dict[str, torch.Size]

state_tensor_names() List[str]

the names of the state tensors, should be overridden by specific environments

Raises:

NotImplementedError – state_tensor_names not implemented

Returns:

the names of the state tensors

Return type:

List[str]

action_dimension() Size

the dimension of the action the agent can take

Returns:

the dimension of the action the agent can take

Return type:

torch.Size

get_state(state: Optional[Dict[str, Union[Tensor, int]]] = None) Dict[str, Union[Tensor, int]]

get the state tensors at the current time, should be overridden by specific environments

Parameters:

state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.

Raises:

NotImplementedError – get_state not implemented

Returns:

the state tensors

Return type:

Dict[str, Union[torch.Tensor, int]]

act(action: Tensor, state: Optional[Dict[str, Union[Tensor, int]]] = None) Tuple[Dict[str, Optional[Union[Tensor, int]]], Tensor, bool]

update the environment with the given action at the given time, should be overridden by specific environments

Parameters:
  • action (torch.Tensor) – the action to perform

  • state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.

Raises:

NotImplementedError – act not implemented

Returns:

the new state, reward, and whether the episode is done

Return type:

Tuple[Dict[str, Optional[Union[torch.Tensor, int]]], torch.Tensor, bool]

reset() None

reset the environment, should be overridden by specific environments

Raises:

NotImplementedError – reset not implemented
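
Taken together, reset, get_state, act, and train_time_range support a standard rollout loop. The sketch below is a hedged illustration: agent.choose_action is a placeholder policy, and whether act alone advances the environment's internal time or an explicit update call is also needed depends on the concrete subclass.

def run_episode(env, agent):
    # hedged rollout sketch over the BaseEnv interface documented above
    env.reset()
    state = env.get_state()
    total_reward = 0.0
    for _ in env.train_time_range():
        action = agent.choose_action(state)
        # act returns (new state, reward, done); depending on the concrete
        # environment, a separate env.update(...) call may be needed to commit it
        state, reward, done = env.act(action)
        total_reward += float(reward)
        if done:
            break
    return total_reward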

update(trading_size: Tensor = None, state: Optional[Dict[str, Union[Tensor, int]]] = None, modify_inner_state: Optional[bool] = None) Dict[str, Union[Tensor, int]]

update the environment with the given trading size of each asset

Parameters:
  • trading_size (torch.Tensor) – the trading size of each asset. Defaults to None.

  • state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.

  • modify_inner_state (Optional[bool], optional) – whether to modify the inner state. Defaults to None.

Returns:

the new state

Return type:

Dict[str, Union[torch.Tensor, int]]

_concat_weight(portfolio_weight: Tensor, rf_weight: Tensor) Tensor

concatenate the portfolio weight and the risk-free weight, placing the risk-free weight at the first position

Parameters:
  • portfolio_weight (torch.Tensor) – the portfolio weight

  • rf_weight (torch.Tensor) – the risk free weight

Returns:

the concatenated weight

Return type:

torch.Tensor
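
A one-line sketch of the described behavior, assuming both weights are 1-D tensors:

import torch

def concat_weight(portfolio_weight: torch.Tensor, rf_weight: torch.Tensor) -> torch.Tensor:
    # risk-free weight first, followed by the per-asset weights
    return torch.cat([rf_weight, portfolio_weight], dim=-1)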

_get_price_tensor(time_index: Optional[int] = None) Tensor

get the price tensor at a given time, should be overridden by specific environments

Parameters:

time_index (Optional[int], optional) – the time index to get the price tensor. Defaults to None, which means to get the price tensor at the current time.

Raises:

NotImplementedError – _get_price_tensor not implemented

Returns:

the price tensor

Return type:

torch.Tensor

_get_price_change_ratio_tensor(time_index: Optional[int] = None) Tensor

get the price change ratio tensor at a given time, should be overridden by specific environments

Parameters:

time_index (Optional[int], optional) – the time index to get the price change ratio. Defaults to None, which means to get the price change ratio at the current time.

Raises:

NotImplementedError – _get_price_change_ratio_tensor not implemented

Returns:

the price change ratio tensor

Return type:

torch.Tensor

_transaction_cost(trading_size: Tensor) Tensor

compute the transaction cost of the trading

transaction_cost = sum(abs(trading_size) for trading_size > 0) * transaction_cost_rate_buy
                 + sum(abs(trading_size) for trading_size < 0) * transaction_cost_rate_sell
                 + count(trading_size != 0) * transaction_cost_base
Parameters:

trading_size (torch.Tensor) – the trading size of each asset

Returns:

the transaction cost

Return type:

torch.Tensor
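
A hedged sketch of the formula above; the rate and base-fee values would normally come from the parsed arguments, and the parameter names here are assumptions:

import torch

def transaction_cost(
    trading_size: torch.Tensor,
    rate_buy: float = 0.0025,   # assumed transaction_cost_rate_buy
    rate_sell: float = 0.0025,  # assumed transaction_cost_rate_sell
    cost_base: float = 0.0,     # assumed transaction_cost_base
) -> torch.Tensor:
    buy_amount = trading_size.clamp(min=0).sum()          # total value bought
    sell_amount = trading_size.clamp(max=0).abs().sum()   # total value sold
    n_trades = (trading_size != 0).sum()                  # number of traded assets
    return buy_amount * rate_buy + sell_amount * rate_sell + n_trades * cost_base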

_get_trading_size_according_to_weight_after_trade(portfolio_weight_before_trade: Tensor, portfolio_weight_after_trade: Tensor, portfolio_value_before_trade: Tensor) Tuple[Tensor, Tensor]

find the trading size according to the weights before and after trading (within the same day, i.e. without transitioning to the next day)

Reference: https://arxiv.org/abs/1706.10059 (Section 2.3)

Parameters:
  • portfolio_weight_before_trade (torch.Tensor) – the weight before trading

  • portfolio_weight_after_trade (torch.Tensor) – the weight after trading

  • portfolio_value_before_trade (torch.Tensor) – the portfolio value before trading

Returns:

the trading size and the mu

Return type:

Tuple[torch.Tensor, torch.Tensor]
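
A sketch of one way to realize this, following the fixed-point iteration for the transaction remainder factor mu in the referenced paper (Section 2.3). The commission rates, the number of iterations, and the cash-as-residual convention are assumptions, not the package's actual implementation:

import torch

def trading_size_from_weights(
    w_before: torch.Tensor,      # per-asset weights before trading (cash excluded)
    w_after: torch.Tensor,       # target per-asset weights after trading (cash excluded)
    value_before: torch.Tensor,  # scalar portfolio value before trading
    c_buy: float = 0.0025,       # assumed buying commission rate
    c_sell: float = 0.0025,      # assumed selling commission rate
    n_iter: int = 20,
):
    cash_before = 1.0 - w_before.sum()
    cash_after = 1.0 - w_after.sum()
    mu = torch.tensor(1.0 - c_buy - c_sell)  # initial guess
    for _ in range(n_iter):
        shrink = (c_sell + c_buy - c_sell * c_buy) * torch.relu(w_before - mu * w_after).sum()
        mu = (1.0 - c_buy * cash_before - shrink) / (1.0 - c_buy * cash_after)
    # portfolio value after costs is mu * value_before, so the traded amounts are:
    trading_size = w_after * mu * value_before - w_before * value_before
    return trading_size, mu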

_get_new_portfolio_weight_and_value(trading_size: Tensor, state: Optional[Dict[str, Union[Tensor, int]]] = None) Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]

get the new portfolio weight and value after trading and transitioning to the next day

Parameters:
  • trading_size (torch.Tensor) – the trading size of each asset

  • state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.

Returns:

the new portfolio weight, the new portfolio weight at the next day, the new risk free weight, the new risk free weight at the next day, the new portfolio value, the new portfolio value at the next day, and the portfolio value at the next day with static weight

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

_cash_shortage(trading_size: Tensor, portfolio_value: Optional[Tensor] = None, rf_weight: Optional[Tensor] = None) bool

check whether there is a cash shortage after trading

Parameters:
  • trading_size (torch.Tensor) – the trading size of each asset

  • portfolio_value (Optional[torch.Tensor], optional) – the portfolio value. Defaults to None.

  • rf_weight (Optional[torch.Tensor], optional) – the risk free weight. Defaults to None.

Returns:

whether there is cash shortage after trading

Return type:

bool

_asset_shortage(trading_size: Tensor, portfolio_weight: Optional[Tensor] = None, portfolio_value: Optional[Tensor] = None) bool

check whether there is an asset shortage after trading

Parameters:
  • trading_size (torch.Tensor) – the trading size of each asset

  • portfolio_weight (Optional[torch.Tensor], optional) – the portfolio weight. Defaults to None (use the current portfolio weight).

  • portfolio_value (Optional[torch.Tensor], optional) – the portfolio value. Defaults to None (use the current portfolio value).

Returns:

whether there is asset shortage after trading

Return type:

bool
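
Hedged sketches of what these two checks typically amount to; the cost helper and the exact accounting are assumptions rather than the package's code:

import torch

def cash_shortage(trading_size, portfolio_value, rf_weight, cost_fn) -> bool:
    # net cash needed for buying (positive sizes) plus transaction costs must not
    # exceed the cash currently held in the risk-free asset
    cash_needed = trading_size.sum() + cost_fn(trading_size)
    return bool(cash_needed > rf_weight * portfolio_value)

def asset_shortage(trading_size, portfolio_weight, portfolio_value) -> bool:
    # no asset can be sold in a larger amount than is currently held
    holdings = portfolio_weight * portfolio_value
    return bool((holdings + trading_size < 0).any())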

select_random_action() Tensor

select a random action, should be overridden by specific environments

Raises:

NotImplementedError – select_random_action not implemented

Returns:

the random action

Return type:

torch.Tensor

get_momentum_action() Tensor

get the momentum action, should be overridden by specific environments

Raises:

NotImplementedError – get_momentum_action not implemented

Returns:

the momentum action

Return type:

torch.Tensor

get_reverse_momentum_action() Tensor

get the reverse momentum action, should be overridden by specific environments

Raises:

NotImplementedError – get_reverse_momentum_action not implemented

Returns:

the reverse momentum action

Return type:

torch.Tensor

envs.BasicContinuousRealDataEnv module

class envs.BasicContinuousRealDataEnv.BasicContinuousRealDataEnv(args: Namespace, data: Data, device: Optional[device] = None)

Bases: BasicRealDataEnv

static add_args(parser: ArgumentParser) None

add arguments to the parser

to add arguments to the parser, modify the method as follows:

@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        ...
    )

then add arguments to the parser

Parameters:

parser (argparse.ArgumentParser) – the parser to add arguments to

__init__(args: Namespace, data: Data, device: Optional[device] = None) None

initialize the environment

Parameters:
  • args (argparse.Namespace) – arguments

  • data (Data) – the data

  • device (Optional[str], optional) – device to run the environment. Defaults to None, which means to use the GPU if available.

to(device: str) None

move the environment to the given device

Parameters:

device (torch.device) – the device to move to

state_dimension() Dict[str, Size]

the dimension of the state tensors.

Returns:

the dimension of the state tensors

Return type:

Dict[str, torch.Size]

state_tensor_names() List[str]

the names of the state tensors

Returns:

the names of the state tensors

Return type:

List[str]

get_state(state: Optional[Dict[str, Union[Tensor, int]]] = None) Optional[Dict[str, Union[Tensor, int]]]

get the state tensors at the current time.

Parameters:

state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.

Returns:

the state tensors

Return type:

Dict[str, torch.Tensor]

act(action: Tensor, state: Optional[Dict[str, Union[Tensor, int]]] = None) Tuple[Dict[str, Optional[Union[Tensor, int]]], Tensor, bool]

perform an action (the trading size)

Parameters:
  • action (torch.Tensor) – the action to perform

  • state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.

Returns:

the new state, reward, and whether the episode is done

Return type:

Tuple[Dict[str, Optional[Union[torch.Tensor, int]]], torch.Tensor, bool]

update(action: Tensor = None, state: Optional[Dict[str, Union[Tensor, int]]] = None, modify_inner_state: Optional[bool] = None) Dict[str, Union[Tensor, int]]

update the environment

Parameters:
  • action (torch.Tensor) – the action to perform. Defaults to None.

  • state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.

  • modify_inner_state (Optional[bool], optional) – whether to modify the inner state. Defaults to None.

Returns:

the new state

Return type:

Dict[str, Union[torch.Tensor, int]]

select_random_action() Tensor

select a random action

Returns:

the random action

Return type:

torch.Tensor

get_momentum_action() Tensor

get the momentum action

Returns:

the momentum action

Return type:

torch.Tensor

get_reverse_momentum_action() Tensor

get the reverse momentum action

Returns:

the reverse momentum action

Return type:

torch.Tensor

envs.BasicDiscreteRealDataEnv module

class envs.BasicDiscreteRealDataEnv.BasicDiscreteRealDataEnv(args: Namespace, data: Data, device: Optional[device] = None)

Bases: BasicRealDataEnv

static add_args(parser: ArgumentParser) None

add arguments to the parser

to add arguments to the parser, modify the method as follows:

@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        ...
    )

then add arguments to the parser

Parameters:

parser (argparse.ArgumentParser) – the parser to add arguments to

__init__(args: Namespace, data: Data, device: Optional[device] = None) None

initialize the environment

Parameters:
  • args (argparse.Namespace) – arguments

  • data (Data) – the data

  • device (Optional[str], optional) – device to run the environment. Defaults to None, which means to use the GPU if available.

to(device: str) None

move the environment to the given device

Parameters:

device (torch.device) – the device to move to

find_action_index(action: Tensor) int

given an action, find the index of the action in all_actions

Parameters:

action (torch.Tensor) – the trading decision of each asset

Returns:

the index of the action in all_actions, -1 if not found

Return type:

int
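
A straightforward sketch, assuming all_actions is a list of tensors enumerated by the environment:

import torch

def find_action_index(action: torch.Tensor, all_actions: list) -> int:
    # linear scan over the enumerated action set; -1 signals "not found"
    for idx, candidate in enumerate(all_actions):
        if torch.equal(candidate, action):
            return idx
    return -1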

act(action: Tensor, state: Optional[Dict[str, Union[Tensor, int]]] = None) Tuple[Dict[str, Optional[Union[Tensor, int]]], Tensor, bool]

update the environment with the given action at the given time

Parameters:
  • action (torch.Tensor) – the action to take

  • state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.

Returns:

the new state, reward, and whether the episode is done

Return type:

Tuple[Dict[str, Optional[Union[torch.Tensor, int]]], torch.Tensor, bool]

update(action: Tensor = None, state: Optional[Dict[str, Union[Tensor, int]]] = None, modify_inner_state: Optional[bool] = None) Dict[str, Union[Tensor, int]]

update the environment

Parameters:
  • action (torch.Tensor) – the action to perform. Defaults to None.

  • state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.

  • modify_inner_state (Optional[bool], optional) – whether to modify the inner state. Defaults to None.

Returns:

the new state

Return type:

Dict[str, Union[torch.Tensor, int]]

select_random_action() Tensor

select a random action

Returns:

the random action

Return type:

torch.Tensor

possible_actions(state: Dict[str, Tensor] = None) List[Tensor]

get the possible actions

Parameters:

state (Dict[str, torch.Tensor], optional) – the current state. Defaults to None.

Returns:

the possible actions

Return type:

List[torch.Tensor]

get_momentum_action() Tensor

get the momentum action

Returns:

the momentum action

Return type:

torch.Tensor

get_reverse_momentum_action() Tensor

get the reverse momentum action

Returns:

the reverse momentum action

Return type:

torch.Tensor

envs.BasicRealDataEnv module

class envs.BasicRealDataEnv.BasicRealDataEnv(args: Namespace, data: Data, device: Optional[device] = None)

Bases: BaseEnv

static add_args(parser: ArgumentParser) None

add arguments to the parser

to add arguments to the parser, modify the method as follows:

@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        ...
    )

then add arguments to the parser

Parameters:

parser (argparse.ArgumentParser) – the parser to add arguments to

__init__(args: Namespace, data: Data, device: Optional[device] = None) None

initialize the environment

Parameters:
  • args (argparse.Namespace) – arguments

  • data (Data) – the data

  • device (Optional[str], optional) – device to run the environment. Defaults to None, which means to use the GPU if available.

to(device: str) None

move the environment to the given device

Parameters:

device (torch.device) – the device to move to

get_asset_num() int

get the number of assets, excluding the risk-free asset

Returns:

the number of assets

Return type:

int

train_time_range() range

the range of time indices

Returns:

the range of time indices

Return type:

range

test_time_range() range

the range of time indices

Returns:

the range of time indices

Return type:

range

_get_price_change_ratio_tensor(time_index: Optional[int] = None) Tensor

get the price change ratio tensor at a given time

Parameters:

time_index (Optional[int], optional) – the time index to get the price change ratio. Defaults to None, which means to get the price change ratio at the current time.

Returns:

the price change ratio tensor

Return type:

torch.Tensor

_get_price_tensor(time_index: Optional[int] = None) Tensor

get the price tensor at a given time

Parameters:

time_index (Optional[int], optional) – the time index to get the price. Defaults to None, which means to get the price at the current time.

Returns:

the price tensor

Return type:

torch.Tensor

_get_price_tensor_in_window(time_index: int) Tensor

get the price tensor in a window centered at a given time

Parameters:

time_index (int) – the time index to get the price
Returns:

the price tensor in the window

Return type:

torch.Tensor

_get_high_price_tensor_in_window(time_index: int) Tensor

get the high price tensor in a window centered at a given time

Parameters:

time_index (int) – the time index to get the high price

Returns:

the high price tensor in the window

Return type:

torch.Tensor

_get_low_price_tensor_in_window(time_index: int) Tensor

get the low price tensor in a window centered at a given time

Parameters:

time_index (int) – the time index to get the low price

Returns:

the low price tensor in the window

Return type:

torch.Tensor

_get_open_price_tensor_in_window(time_index: int) Tensor

get the open price tensor in a window centered at a given time

Parameters:

time_index (int) – the time index to get the open price

Returns:

the open price tensor in the window

Return type:

torch.Tensor

_get_tensor_in_window(tensor: Tensor, time_index: int) Tensor

get the tensor in a window centered at a given time

Parameters:
  • tensor (torch.Tensor) – the tensor

  • time_index (int) – the time index to get the tensor

Returns:

the tensor in the window

Return type:

torch.Tensor
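
A sketch of the windowing helper. It assumes a trailing window of window_size steps ending at time_index, with time as the first tensor dimension; the actual alignment of the window should be checked against the source:

import torch

def get_tensor_in_window(tensor: torch.Tensor, time_index: int, window_size: int) -> torch.Tensor:
    # slice the last `window_size` steps up to and including `time_index`
    start = max(0, time_index - window_size + 1)
    return tensor[start : time_index + 1]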

state_dimension() Dict[str, Size]

the dimension of the state tensors.

Returns:

the dimension of the state tensors

Return type:

Dict[str, torch.Size]

state_tensor_names() List[str]

the names of the state tensors

Returns:

the names of the state tensors

Return type:

List[str]

get_state(state: Optional[Dict[str, Union[Tensor, int]]] = None) Optional[Dict[str, Union[Tensor, int]]]

get the state tensors at the current time.

Parameters:

state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.

Returns:

the state tensors

Return type:

Dict[str, Union[torch.Tensor, int]]

reset() None

reset the environment.

envs.ContinuousRealDataEnv1 module

class envs.ContinuousRealDataEnv1.ContinuousRealDataEnv1(args: Namespace, data: Data, device: Optional[device] = None)

Bases: BasicContinuousRealDataEnv

Reference:

https://arxiv.org/abs/1706.10059

static add_args(parser: ArgumentParser) None

add arguments to the parser

to add arguments to the parser, modify the method as follows:

@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        ...
    )

then add arguments to the parser

Parameters:

parser (argparse.ArgumentParser) – the parser to add arguments to

__init__(args: Namespace, data: Data, device: Optional[device] = None) None

initialize the environment

Parameters:
  • args (argparse.Namespace) – arguments

  • data (Data) – the data

  • device (Optional[str], optional) – device to run the environment. Defaults to None, which means to use the GPU if available.

to(device: str) None

move the environment to the given device

Parameters:

device (torch.device) – the device to move to

state_dimension() Dict[str, Size]

the dimension of the state tensors.

Returns:

the dimension of the state tensors

Return type:

Dict[str, torch.Size]

state_tensor_names() List[str]

the names of the state tensors

Returns:

the names of the state tensors

Return type:

List[str]

get_state(state: Optional[Dict[str, Union[Tensor, int]]] = None) Optional[Dict[str, Union[Tensor, int]]]

get the state tensors at the current time.

Parameters:

state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.

Returns:

the state tensors

Return type:

Dict[str, torch.Tensor]

act(action_weight: Tensor, state: Optional[Dict[str, Union[Tensor, int]]] = None) Tuple[Dict[str, Optional[Union[Tensor, int]]], Tensor, bool]

perform an action (the portfolio weight)

Parameters:
  • action_weight (torch.Tensor) – the action to perform

  • state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.

Returns:

the new state, reward, and whether the episode is done

Return type:

Tuple[Dict[str, Optional[Union[torch.Tensor, int]]], torch.Tensor, bool]

update(action_weight: Tensor = None, state: Optional[Dict[str, Union[Tensor, int]]] = None, modify_inner_state: Optional[bool] = None) Dict[str, Union[Tensor, int]]

update the environment

Parameters:
  • action_weight (torch.Tensor) – the action to perform, i.e. the portfolio weight after trading. Defaults to None.

  • state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.

  • modify_inner_state (Optional[bool], optional) – whether to modify the inner state. Defaults to None.

Returns:

the new state

Return type:

Dict[str, Union[torch.Tensor, int]]

select_random_action() Tensor

select a random action

Returns:

the random action

Return type:

torch.Tensor

get_momentum_action() Tensor

get the momentum action

Returns:

the momentum action

Return type:

torch.Tensor

get_reverse_momentum_action() Tensor

get the reverse momentum action

Returns:

the reverse momentum action

Return type:

torch.Tensor

envs.DiscreteRealDataEnv1 module

class envs.DiscreteRealDataEnv1.DiscreteRealDataEnv1(args: Namespace, data: Data, device: Optional[device] = None)

Bases: BasicDiscreteRealDataEnv

Reference:

original paper: https://arxiv.org/abs/1907.03665

static add_args(parser: ArgumentParser) None

add arguments to the parser

to add arguments to the parser, modify the method as follows:

@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        ...
    )

then add arguments to the parser

Parameters:

parser (argparse.ArgumentParser) – the parser to add arguments to

__init__(args: Namespace, data: Data, device: Optional[device] = None) None

initialize the environment

Parameters:
  • args (argparse.Namespace) – arguments

  • data (Data) – the data

  • device (Optional[str], optional) – device to run the environment. Defaults to None, which means to use the GPU if available.

to(device: str) None

move the environment to the given device

Parameters:

device (torch.device) – the device to move to

sample_distribution_and_set_episode() int

sample a distribution and set the episode accordingly; please refer to the paper https://arxiv.org/abs/1907.03665 for more details

Returns:

the episode index

Return type:

int

set_episode(episode: int) None

set the episode given the episode index

Parameters:

episode (int) – the episode index

set_episode_for_testing() None

special function to set the episode for testing

train_time_range() range

the range of time indices

Returns:

the range of time indices

Return type:

range

test_time_range() range

the range of time indices

Returns:

the range of time indices

Return type:

range

pretrain_train_time_range(shuffle: bool = True) List

the list of time indices for pretraining

Parameters:

shuffle (bool, optional) – whether to shuffle. Defaults to True.

Returns:

the list of time indices

Return type:

List

pretrain_eval_time_range() range

the list of time indices for pretraining evaluation

Returns:

the list of time indices

Return type:

range

state_dimension() Dict[str, Size]

the dimension of the state tensors, including Xt_Matrix and Portfolio_Weight

Returns:

the dimension of the state tensors

Return type:

Dict[str, torch.Size]

state_tensor_names() List[str]

the names of the state tensors, including Xt_Matrix and Portfolio_Weight

Returns:

the names of the state tensors

Return type:

List[str]

get_state(state: Optional[Dict[str, Union[Tensor, int]]] = None) Optional[Dict[str, Union[Tensor, int]]]

get the state tensors at the current time, including Xt_Matrix and Portfolio_Weight

Parameters:

state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.

Returns:

the state tensors

Return type:

Dict[str, torch.Tensor]

_get_Xt_state(time_index: Optional[int] = None) Tensor

get the Xt state tensor at a given time

Parameters:

time_index (Optional[int], optional) – the time_index. Defaults to None, which means to get the Xt state tensor at the current time.

Returns:

the Xt state tensor

Return type:

torch.Tensor

get_pretrain_input_and_target(time_index: int) Optional[Tuple[Dict[str, Tensor], Tensor]]

get the Xt state tensor and the pretrain target at a given time

Parameters:

time_index (int) – the time index

Returns:

the input state and the pretrain target

Return type:

Optional[Tuple[Dict[str, torch.Tensor], torch.Tensor]]

act(action_idx: int, state: Optional[Dict[str, Union[Tensor, int]]] = None) Tuple[Dict[str, Optional[Union[Tensor, int]]], Tensor, bool]

update the environment with the given action at the given time

Parameters:
  • action_idx (int) – the index of the action to take

  • state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.

Raises:

ValueError – action not valid

Returns:

the new state, reward, and whether the episode is done

Return type:

Tuple[Dict[str, Optional[Union[torch.Tensor, int]]], torch.Tensor, bool]

update(action: int = None, state: Optional[Dict[str, Union[Tensor, int]]] = None, modify_inner_state: Optional[bool] = None) Dict[str, Union[Tensor, int]]

update the environment with the given action index

Parameters:
  • action (int) – the index of the action to take. Defaults to None.

  • state (Optional[Dict[str, Union[torch.Tensor, int]]], optional) – the state tensors. Defaults to None.

  • modify_inner_state (Optional[bool], optional) – whether to modify the inner state. Defaults to None.

Raises:

ValueError – the action is not valid

Returns:

the new state

Return type:

Dict[str, Union[torch.Tensor, int]]

reset() None

reset the environment to the initial state

_cash_shortage(action: Tensor, portfolio_value: Optional[Tensor] = None, rf_weight: Optional[Tensor] = None) bool

check whether there is a cash shortage after trading

Parameters:
  • action (torch.Tensor) – the trading decision of each asset

  • portfolio_value (Optional[torch.Tensor], optional) – the portfolio value. Defaults to None.

  • rf_weight (Optional[torch.Tensor], optional) – the risk free weight. Defaults to None.

Returns:

whether there is cash shortage after trading

Return type:

bool

_asset_shortage(action: Tensor, portfolio_weight: Optional[Tensor] = None, portfolio_value: Optional[Tensor] = None) bool

check whether there is an asset shortage after trading

Parameters:
  • action (torch.Tensor) – the trading decision of each asset

  • portfolio_weight (Optional[torch.Tensor], optional) – the portfolio weight. Defaults to None.

  • portfolio_value (Optional[torch.Tensor], optional) – the portfolio value. Defaults to None.

Returns:

whether there is asset shortage after trading

Return type:

bool

_action_validity(action: Tensor, portfolio_weight: Optional[Tensor] = None, portfolio_value: Optional[Tensor] = None, rf_weight: Optional[Tensor] = None) bool

check whether the action is valid

Parameters:
  • action (torch.Tensor) – the trading decision of each asset

  • portfolio_weight (Optional[torch.Tensor], optional) – the portfolio weight. Defaults to None.

  • portfolio_value (Optional[torch.Tensor], optional) – the portfolio value. Defaults to None.

  • rf_weight (Optional[torch.Tensor], optional) – the risk free weight. Defaults to None.

Returns:

whether the action is valid

Return type:

bool
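
Presumably this validity check just combines the two shortage tests documented above; a hedged method-style one-liner under that assumption:

def _action_validity(self, action, portfolio_weight=None, portfolio_value=None, rf_weight=None) -> bool:
    # valid only if the action causes neither a cash shortage nor an asset shortage
    return not self._cash_shortage(action, portfolio_value, rf_weight) and not self._asset_shortage(
        action, portfolio_weight, portfolio_value
    )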

possible_actions() Tensor

get all possible action indexes

Returns:

all possible action indexes

Return type:

torch.Tensor

action_mapping(action_index: int, Q_Values: Tensor, state: Optional[Dict[str, Tensor]] = None) int

perform action mapping based on the Q values

Parameters:
  • action_index (int) – the index of the action to map

  • Q_Values (torch.Tensor) – the Q values of all actions

  • state (Optional[Dict[str, torch.Tensor]], optional) – the state tensors. Defaults to None.

Raises:

ValueError – action not valid

Returns:

the index of the mapped action

Return type:

int

_action_mapping_rule1(action: Tensor, Q_Values: Tensor, portfolio_value: Optional[Tensor] = None, rf_weight: Optional[Tensor] = None) int

action mapping rule 1: if there is a cash shortage, find the subset action with the highest Q value

Parameters:
  • action (torch.Tensor) – the trading decision of each asset

  • Q_Values (torch.Tensor) – the Q values of all actions

  • portfolio_value (Optional[torch.Tensor], optional) – the portfolio value. Defaults to None.

  • rf_weight (Optional[torch.Tensor], optional) – the risk free weight. Defaults to None.

Returns:

the index of the mapped action

Return type:

int

_action_mapping_rule2(action: Tensor, portfolio_weight: Tensor, portfolio_value: Tensor) int

action mapping rule 2: if there is an asset shortage, do not trade the assets that are in shortage

Parameters:
  • action (torch.Tensor) – the trading decision of each asset

  • portfolio_weight (torch.Tensor) – the portfolio weight

  • portfolio_value (torch.Tensor) – the portfolio value

Returns:

the index of the mapped action

Return type:

int

select_random_action() int

select a random valid action, return its index

Returns:

the index of the selected action

Return type:

int

get_momentum_action() int

get the momentum action

Returns:

the momentum action

Return type:

int

get_reverse_momentum_action() int

get the reverse momentum action

Returns:

the reverse momentum action

Return type:

int

envs.DiscreteRealDataEnv2 module

class envs.DiscreteRealDataEnv2.DiscreteRealDataEnv2(args: Namespace, data: Data, device: Optional[device] = None)

Bases: BasicDiscreteRealDataEnv

static add_args(parser: ArgumentParser) None

add arguments to the parser

to add arguments to the parser, modify the method as follows:

@staticmethod
def add_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        ...
    )

then add arguments to the parser

Parameters:

parser (argparse.ArgumentParser) – the parser to add arguments to

__init__(args: Namespace, data: Data, device: Optional[device] = None) None

initialize the environment

Parameters:
  • args (argparse.Namespace) – arguments

  • data (Data) – the data

  • device (Optional[str], optional) – device to run the environment. Defaults to None, which means to use the GPU if available.

action_is_valid(action: Tensor, state: Optional[Dict[str, Tensor]] = None) bool

check if the action is valid

Parameters:
  • action (torch.Tensor) – the action

  • state (Optional[Dict[str, torch.Tensor]], optional) – the current state. Defaults to None.

Returns:

whether the action is valid

Return type:

bool

possible_actions(state: Dict[str, Tensor] = None) List[Tensor]

get the possible actions

Parameters:

state (Dict[str, torch.Tensor], optional) – the current state. Defaults to None.

Returns:

the possible actions

Return type:

List[torch.Tensor]

get_momentum_action() Tensor

get the momentum action

Returns:

the momentum action

Return type:

torch.Tensor

get_reverse_momentum_action() Tensor

get the reverse momentum action

Returns:

the reverse momentum action

Return type:

torch.Tensor